import torch
import torch.nn as nn
from torch.utils.data.dataset import Dataset
import numpy as np
from IPython import embed

class DataSet(Dataset):
    """Match-level dataset read from a blank-line-separated text file.

    Each record (separated by an empty line) is parsed as follows. The first
    two lines are header lines that are not used (they were named
    ``score1``/``score2`` originally). Every following line is a
    ``;``-separated player entry. Its first field is parsed as an integer
    player id, and its last field equals ``"Win"`` for players on the winning
    side; everything else is treated as the losing side.

    Items are ``(side_a_ids, side_b_ids, label)``. Per record, a random coin
    flip fixed at construction decides whether the winners come first
    (label ``1``) or second (label ``0``).
    """

    def __init__(self, text_path, train_split=0.8):
        """Load all match records from ``text_path``.

        Args:
            text_path: Path of the text file holding the match records.
            train_split: Fraction of records that ``get_split_indices``
                assigns to the training split.
        """
        super().__init__()
        # Context manager so the file handle is closed deterministically.
        with open(text_path) as f:
            raw = f.read()
        # Drop empty / whitespace-only records, e.g. produced by trailing or
        # repeated blank lines; such records cannot be parsed anyway.
        self.data = [record for record in raw.split("\n\n") if record.strip()]

        self.train_split = train_split

        # One fixed coin flip per record: True -> winners are returned first.
        self.winner_first = np.random.random(len(self.data)) > 0.5

    def __getitem__(self, index):
        """Return ``(first_ids, second_ids, label)`` for record ``index``.

        Raises:
            ValueError: If the record has fewer than two lines, or a player
                id is not an integer.
        """
        # Strip surrounding whitespace so a trailing newline (or leftover
        # newlines from extra blank lines) does not yield empty player lines.
        lines = self.data[index].strip().splitlines()
        if len(lines) < 2:
            raise ValueError(
                f"record {index} has {len(lines)} line(s); expected at least 2"
            )
        # lines[:2] are the unused header lines.
        fields = [line.strip().split(";") for line in lines[2:]]

        # int64 explicitly: platform-independent and suitable for index tensors.
        names = np.array([int(f[0]) for f in fields], dtype=np.int64)
        winner = np.array([f[-1] == "Win" for f in fields], dtype=bool)
        loser = ~winner

        if self.winner_first[index]:
            return torch.tensor(names[winner]), torch.tensor(names[loser]), torch.ones(1)
        return torch.tensor(names[loser]), torch.tensor(names[winner]), torch.zeros(1)

    def __len__(self):
        """Return the number of parsed records."""
        return len(self.data)

    def get_split_indices(self):
        """Randomly partition record indices into train / validation lists.

        The first ``floor(train_split * len(self))`` shuffled indices form the
        training split; the remainder is the validation split.

        Returns:
            Tuple ``(train_indices, val_indices)`` of lists of ints.
        """
        indices = list(range(len(self)))
        np.random.shuffle(indices)
        cutoff = int(np.floor(self.train_split * len(self)))
        return indices[:cutoff], indices[cutoff:]

class Single_DataSet(Dataset):
    """Dataset of whitespace-separated integer features plus a float target.

    Records are separated by an empty line. Within a record, all
    whitespace-separated tokens except the last are parsed as integer
    features. The last token is parsed as a float score.
    """

    def __init__(self, text_path, num_players, split=(0.7, 0.1, 0.2)):
        """Load all records from ``text_path``.

        Args:
            text_path: Path of the text file holding the records.
            num_players: Stored as ``self.num_players``; not used by this
                class itself.
            split: ``(train, val, test)`` fractions used by
                ``get_split_indices``. Copied into a fresh list, so instances
                never share (or mutate) a caller's / default sequence.
        """
        super().__init__()

        # Context manager so the file handle is closed deterministically.
        with open(text_path) as f:
            raw = f.read()
        # Drop empty / whitespace-only records (trailing or repeated blank
        # lines); they have no score token and cannot be parsed.
        self.data = [record for record in raw.split("\n\n") if record.strip()]

        self.split = list(split)
        self.num_players = num_players

    def __getitem__(self, index, normalized=False):
        """Return ``(features, target)`` for record ``index``.

        Args:
            index: Record index.
            normalized: If True and the feature sum is positive, divide the
                score by the feature sum.

        Returns:
            ``(torch.tensor(features), torch.tensor([score]))``.
        """
        tokens = self.data[index].split()
        score = float(tokens[-1])
        feats = [int(indicator) for indicator in tokens[:-1]]
        total = sum(feats)
        norm = total if normalized and total > 0 else 1

        return torch.tensor(feats), torch.tensor([score / norm])

    def __len__(self):
        """Return the number of parsed records."""
        return len(self.data)

    def get_split_indices(self):
        """Randomly partition record indices into train / val / test lists.

        Split sizes are ``floor(train * n)`` and ``floor(val * n)``; all
        remaining indices go to the test split.

        Returns:
            Tuple ``(train_indices, val_indices, test_indices)`` of lists.
        """
        n = len(self)
        indices = list(range(n))
        np.random.shuffle(indices)
        train_perc, val_perc, _test_perc = self.split

        train_cutoff = int(np.floor(train_perc * n))
        val_count = int(np.floor(val_perc * n))
        val_end = train_cutoff + val_count
        return indices[:train_cutoff], indices[train_cutoff:val_end], indices[val_end:]

